import json
import logging
import os
import random
import re
import time
from pathlib import Path
from typing import List, Dict, Optional

from openai import OpenAI
from tqdm import tqdm

# Root logger: INFO and above, prefixed with timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class Config:
    """Static settings for the GSM8K chain-of-thought generation run.

    Credentials come from the environment instead of being hard-coded, so
    secrets stay out of source control. A missing value is ``None`` rather
    than ``""``: the OpenAI client treats ``None`` as "use your own default",
    whereas an empty string would be used verbatim and break authentication
    or URL construction.
    """

    # Number of leading items from the input file to process.
    SAMPLE_SIZE = 300
    # Seconds to sleep after each API request.
    API_RATE_LIMIT = 2
    # Attempts per request before a question is given up on.
    MAX_RETRIES = 3

    BASE_PATH = Path("/process_COT/gsm8k")
    INPUT_FILE = BASE_PATH / "gsm8k.json"
    OUTPUT_FILE = BASE_PATH / "reasoning_output_gsm8k_correct.txt"
    PROGRESS_FILE = BASE_PATH / "progress.json"

    # `or None` also maps an exported-but-empty variable to None.
    API_KEY = os.environ.get("OPENAI_API_KEY") or None
    BASE_URL = os.environ.get("OPENAI_BASE_URL") or None
    MODEL_NAME = "gpt-4o"

class QuestionProcessor:
    """Generate step-by-step GSM8K solutions with an OpenAI chat model.

    For each question, a few-shot prompt is built from a file of worked
    examples plus the question's reference solution. The model's reply is
    normalised and written to ``Config.OUTPUT_FILE``. Question texts that
    were handled successfully are recorded in ``Config.PROGRESS_FILE`` so
    an interrupted run can be resumed without redoing them.
    """

    # Few-shot examples prepended (possibly truncated) to every prompt.
    EXAMPLES_FILE = Path("/gsm8k/cot_8_clean.txt")
    # First number on a "The answer is ..." line: optional minus sign,
    # digits with optional thousands separators, optional decimal part.
    # A sentence-final "." is not consumed because it has no digits after it.
    _ANSWER_NUMBER_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?")
    _ANSWER_PREFIX = "The answer is"

    def __init__(self):
        # An empty string would be used verbatim by the client; None lets it
        # fall back to its environment variables / default endpoint.
        self.client = OpenAI(
            api_key=Config.API_KEY or None,
            base_url=Config.BASE_URL or None
        )
        self.processed_questions = self._load_progress()
        self.examples = self._load_examples()

    def _load_progress(self) -> set:
        """Return the set of question texts recorded as processed, if any."""
        if Config.PROGRESS_FILE.exists():
            with open(Config.PROGRESS_FILE, 'r', encoding='utf-8') as f:
                return set(json.load(f))
        return set()

    def _save_progress(self, question: str):
        """Mark *question* as processed and persist the whole progress set.

        The JSON is written to a temporary sibling file and then renamed over
        the real one. A crash mid-write therefore cannot leave a truncated
        progress file that would break the next run's load.
        """
        self.processed_questions.add(question)
        tmp_path = Config.PROGRESS_FILE.with_name(Config.PROGRESS_FILE.name + '.tmp')
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(list(self.processed_questions), f)
        tmp_path.replace(Config.PROGRESS_FILE)

    def _load_examples(self) -> str:
        """Read the few-shot example text, stripped of surrounding whitespace."""
        with open(self.EXAMPLES_FILE, 'r', encoding='utf-8') as f:
            return f.read().strip()

    def get_completion(self, prompt: str, retries: Optional[int] = None) -> Optional[str]:
        """Send *prompt* to the chat model and return the stripped reply.

        Args:
            prompt: Content of the user message.
            retries: Maximum number of attempts. ``None`` means
                ``Config.MAX_RETRIES``, read at call time.

        Returns:
            The reply text, or ``None`` if every attempt failed (or
            *retries* is not positive). Waits ``2 ** attempt`` seconds
            between attempts.
        """
        if retries is None:
            retries = Config.MAX_RETRIES
        for attempt in range(retries):
            try:
                response = self.client.chat.completions.create(
                    model=Config.MODEL_NAME,
                    messages=[
                        {"role": "system", "content": """You are a math teacher who explains solutions step by step.
Follow the examples exactly - use simple language, show each calculation clearly, and format the answer the same way."""},
                        {"role": "user", "content": prompt}
                    ]
                )
                content = response.choices[0].message.content
                if content is None:
                    # Treat a content-less reply like any other failed attempt.
                    raise ValueError("model returned no message content")
                return content.strip()
            except Exception as e:
                logging.warning(f"API call failed (attempt {attempt + 1}/{retries}): {str(e)}")
                if attempt < retries - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff
                else:
                    logging.error(f"API call final failure: {str(e)}")
        return None

    def format_question(self, qa_item: Dict) -> str:
        """Render a GSM8K item as question, reference steps, and final answer.

        ``qa_item['answer']`` is split on ``####``. The text before it is the
        reasoning, and the text after it is the final answer (empty if no
        ``####`` is present).
        """
        answer_parts = qa_item['answer'].split('####')
        reasoning = answer_parts[0].strip()
        final_answer = answer_parts[1].strip() if len(answer_parts) > 1 else ""

        return f"""Question: {qa_item['question']}
Original Solution Steps:
{reasoning}
Final Answer: {final_answer}"""

    def generate_prompt(self, formatted_q: str, max_example_chars: Optional[int] = 500) -> str:
        """Build the few-shot prompt for one formatted question.

        Args:
            formatted_q: Output of :meth:`format_question`.
            max_example_chars: Keep only this many leading characters of the
                example text; ``None`` includes all of it.
        """
        if max_example_chars is None:
            examples = self.examples
        else:
            examples = self.examples[:max_example_chars]
        return f"""Solve this math problem exactly following this format:
1. Start with "Question: " followed by the complete problem statement
2. Write your solution with clear steps and calculations
3. End with "The answer is [number]"

Here are some examples of the format:

{examples}

Now solve this problem:
{formatted_q}"""

    @classmethod
    def _normalize_answer_line(cls, line: str) -> str:
        """Reduce a "The answer is ..." line to "The answer is <number>".

        Keeps the first number's sign and decimal part and drops thousands
        separators and trailing punctuation. The line is returned unchanged
        if it contains no number.
        """
        match = cls._ANSWER_NUMBER_RE.search(line, len(cls._ANSWER_PREFIX))
        if match is None:
            return line
        return f"{cls._ANSWER_PREFIX} {match.group().replace(',', '')}"

    def format_response(self, response: str, original_question: str) -> str:
        """Normalise a model reply into one output record.

        The record starts with exactly one "Question:" line (the original
        question is prepended if the reply lacks one). Blank lines and lines
        echoing the prompt's "Original Solution" / "Final Answer" sections
        are dropped. The answer line is normalised. Two blank lines are
        appended to separate records.
        """
        lines = []

        # Ensure it starts with a question
        if not response.strip().startswith("Question:"):
            lines.append(f"Question: {original_question}")

        for line in response.split('\n'):
            line = line.strip()
            if not line or any(x in line.lower() for x in ['original solution', 'final answer']):
                continue

            # Keep only the first "Question:" line.
            if line.startswith("Question:") and lines and lines[0].startswith("Question:"):
                continue

            if line.startswith(self._ANSWER_PREFIX):
                line = self._normalize_answer_line(line)

            lines.append(line)

        # Two blank lines between records.
        return '\n'.join(lines) + '\n\n\n'

    def process_questions(self, questions: List[Dict]):
        """Generate and write solutions for the first ``Config.SAMPLE_SIZE`` items.

        Already-processed questions are skipped. When progress exists, the
        output file is appended to rather than truncated, so answers from an
        earlier, interrupted run are kept.
        """
        Config.OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)

        selected_questions = questions[:Config.SAMPLE_SIZE]

        # Truncating on resume would discard answers whose questions are
        # skipped below, so they would never be regenerated.
        mode = 'a' if self.processed_questions else 'w'
        with open(Config.OUTPUT_FILE, mode, encoding='utf-8') as f:
            for qa in tqdm(selected_questions, desc=f"Processing first {Config.SAMPLE_SIZE} questions"):
                if qa['question'] in self.processed_questions:
                    logging.info(f"Skipping already processed question: {qa['question']}")
                    continue

                formatted_q = self.format_question(qa)
                prompt = self.generate_prompt(formatted_q)

                if response := self.get_completion(prompt):
                    f.write(self.format_response(response, qa['question']))
                    # Flush before recording progress so a crash cannot mark a
                    # question as done while its answer is still only buffered.
                    f.flush()
                    self._save_progress(qa['question'])

                time.sleep(Config.API_RATE_LIMIT)

def main():
    """Load questions from ``Config.INPUT_FILE`` and run the processor on them.

    Failures are logged rather than raised. Malformed JSON is reported as a
    parsing error. Any other exception is logged with its full traceback,
    which ``str(e)`` alone would lose.
    """
    try:
        # Read question data
        with open(Config.INPUT_FILE, 'r', encoding='utf-8') as f:
            try:
                questions = json.load(f)
            except json.JSONDecodeError as e:
                logging.error(f"JSON parsing error: {str(e)}")
                return
        logging.info(f"loaded {len(questions)} questions")

        # Initialize processor and process questions
        processor = QuestionProcessor()
        processor.process_questions(questions)

        logging.info("Processing completed!")

    except Exception as e:
        # logging.exception also records the traceback of the active exception.
        logging.exception(f"Program error: {str(e)}")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()